import json
import re
import math, ast
from typing import List, Dict, Tuple, Any

class Base_Model:
    """Compare predicted points between a result file and a raw file.

    Both files are JSON. The result file is indexed with
    ``['detailed_results']`` (a list of items); the raw file is iterated
    directly as a list of items. Items are matched on
    ``(image_path[0], episode_id, step_id)``.

    Subclasses supply ``_extract_point`` to turn an item's
    ``predicted_action`` into an ``(x, y)`` pair.
    """

    # Values of ``predicted_action_type`` counted by ``_get_ris``.
    # NOTE(review): the meaning of these codes is not visible in this file;
    # confirm against the dataset's action schema.
    _RIS_ACTION_TYPES = (4, 5, 6, 9, 8)

    def __init__(self, res_path: str, raw_path: str):
        """Load both JSON files and index the raw items by their key."""
        self.res_path = res_path
        self.raw_path = raw_path
        self.res_data = self._read_json(res_path)
        self.raw_data = self._read_json(raw_path)
        self.raw_index = self._build_index(self.raw_data)

    @staticmethod
    def _read_json(path: str) -> Any:
        """Return the parsed JSON content of ``path`` (read as UTF-8)."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _build_index(data: List[Dict[str, Any]]) -> Dict[Tuple[str, int, int], Dict[str, Any]]:
        """Map ``(image_path[0], episode_id, step_id)`` to its item.

        When several items share a key, the last one wins.
        """
        return {
            (item['image_path'][0], item['episode_id'], item['step_id']): item
            for item in data
        }

    @staticmethod
    def _extract_point(action, flag, item):
        """Extract an ``(x, y)`` pair from a ``predicted_action`` value.

        Subclasses must override this. ``compute_distances`` calls it with
        ``flag=True`` for the result-file prediction and ``flag=False`` for
        the raw-file prediction; ``item`` is always the result-file item.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("subclasses must implement _extract_point")

    def compute_distances(self) -> List[Dict[str, Any]]:
        """Return the Euclidean distance between matched predictions.

        A result item is skipped when its key is missing from the raw index
        or when either extracted point contains ``None``. Extractors that
        return numeric sentinel points instead of ``None`` still produce a
        (large) distance entry.
        """
        results = []
        for item in self.res_data['detailed_results']:
            key = (item['image_path'][0], item['episode_id'], item['step_id'])
            if key not in self.raw_index:
                continue
            raw_item = self.raw_index[key]
            x1, y1 = self._extract_point(item['predicted_action'], True, item)
            x2, y2 = self._extract_point(raw_item['predicted_action'], False, item)

            if None in (x1, y1, x2, y2):
                continue
            results.append({
                "image_path": item['image_path'][0],
                "episode_id": item['episode_id'],
                "step_id": item['step_id'],
                "predicted_action_res": (x1, y1),
                "predicted_action_raw": (x2, y2),
                "distance": math.hypot(x1 - x2, y1 - y2),
            })
        return results

    def percentage_below_threshold(self, threshold: float) -> float:
        """Return the percentage of result items with distance <= threshold.

        The denominator is the number of items in ``detailed_results``, not
        the number of computed distances, so skipped items count as misses.
        Returns ``0.0`` when no distance could be computed.
        """
        distances = [r["distance"] for r in self.compute_distances()]
        if not distances:
            return 0.0
        count_below = sum(d <= threshold for d in distances)
        return (count_below / len(self.res_data['detailed_results'])) * 100

    def _get_ris(self):
        """Return the fraction of result items whose action type is counted.

        An item is counted when its ``predicted_action_type`` equals one of
        ``_RIS_ACTION_TYPES``. Uses the result data loaded at construction
        instead of re-reading the file, and returns ``0.0`` for an empty
        result list rather than dividing by zero.
        """
        detailed = self.res_data['detailed_results']
        if not detailed:
            return 0.0
        # Tuple membership compares with ``==``, like the original or-chain.
        count = sum(
            1 for item in detailed
            if item['predicted_action_type'] in self._RIS_ACTION_TYPES
        )
        return count / len(detailed)


class OS_ATLAS(Base_Model):
    """Point extractor for action strings with bracketed coordinates.

    Coordinates are read from the text between ``[[`` and ``]]`` or, when
    the string contains ``<point>`` instead, between the first ``[`` and the
    next ``]``. The constructor is inherited from ``Base_Model``.
    """

    @staticmethod
    def _extract_point(action, flag, item):
        """Return the first two comma-separated numbers found in ``action``.

        Args:
            action: the ``predicted_action`` value; expected to be a string.
            flag: ``True`` for the result-file prediction, ``False`` for the
                raw-file prediction (as passed by
                ``Base_Model.compute_distances``).
            item: mapping providing ``image_size``; read only when ``flag``
                is false, to rescale the point as ``value / size * 1000``.

        Returns:
            ``(x, y)`` as floats, or the sentinel ``(-3000, -3000)`` (flag
            true) / ``(-2000, -2000)`` (flag false) when nothing can be
            parsed.
        """
        failure = (-3000, -3000) if flag else (-2000, -2000)
        if not isinstance(action, str):
            return failure

        if "[[" in action:
            start = action.find("[[") + 2
            # Look for the closer after the opener, not from index 0.
            end = action.find("]]", start)
        elif "<point>" in action:
            opener = action.find("[")
            if opener == -1:
                return failure
            start = opener + 1
            end = action.find("]", start)
        else:
            return failure

        # A missing closer would make the slice end at -1 and silently drop
        # the last character, yielding a wrong coordinate.
        if end == -1:
            return failure

        try:
            coordinates = action[start:end].split(",")
            x, y = float(coordinates[0]), float(coordinates[1])
            if not flag:
                image_size = item['image_size']
                x, y = x / image_size[0] * 1000, y / image_size[1] * 1000
        except (ValueError, IndexError, KeyError, TypeError, ZeroDivisionError):
            return failure
        return x, y

class OS_GENESIS(Base_Model):
    """Point extractor reading the ``'POINT'`` entry of the prediction.

    The constructor is inherited from ``Base_Model``.
    """

    @staticmethod
    def _extract_point(content, flag, item=None):
        """Return ``content['POINT']`` as a pair of floats.

        Args:
            content: the ``predicted_action`` value; indexed with ``'POINT'``.
            flag: ``True`` for the result-file prediction, ``False`` for the
                raw-file prediction.
            item: accepted because ``Base_Model.compute_distances`` passes a
                third argument; unused here.

        Returns:
            ``(float(x), float(y))``, or the sentinel ``(-3000, -3000)``
            (flag true) / ``(-2000, -2000)`` (flag false) when the point is
            missing, malformed or not numeric.
        """
        try:
            x, y = content['POINT']
            return float(x), float(y)
        except (KeyError, IndexError, TypeError, ValueError):
            print("extract coordinates failure")
            if flag:
                return -3000, -3000
            return -2000, -2000

class UI_TARS(Base_Model):
    """Point extractor for predictions containing an ``(x, y)`` literal.

    The constructor is inherited from ``Base_Model``.
    """

    # First parenthesised pair of optionally signed, optionally fractional
    # numbers. Compiled once instead of on every call.
    _POINT_PATTERN = re.compile(r"\(\s*([-+]?\d*\.?\d+)\s*,\s*([-+]?\d*\.?\d+)\s*\)")

    @staticmethod
    def _extract_point(content, flag, item=None):
        """Return the first ``(x, y)`` pair found in ``content`` as floats.

        Args:
            content: the ``predicted_action`` value; expected to be a string.
            flag: accepted for signature compatibility; unused here.
            item: accepted because ``Base_Model.compute_distances`` passes a
                third argument; unused here.

        Returns:
            ``(x, y)`` as floats, or ``(None, None)`` when no pair matches or
            ``content`` is not searchable text. ``compute_distances`` skips
            items whose point contains ``None``.
        """
        try:
            match = UI_TARS._POINT_PATTERN.search(content)
        except TypeError:
            # content is not a str (e.g. None or a dict).
            return None, None
        if match is None:
            return None, None
        return float(match.group(1)), float(match.group(2))

class GUI_Odyssey(Base_Model):
    """Point extractor reading the ``'POINT'`` entry of the prediction.

    The constructor is inherited from ``Base_Model``.
    """

    @staticmethod
    def _extract_point(content, flag, item=None):
        """Return the two values stored under ``content['POINT']``.

        Args:
            content: the ``predicted_action`` value; indexed with ``'POINT'``.
            flag: ``True`` for the result-file prediction, ``False`` for the
                raw-file prediction.
            item: accepted because ``Base_Model.compute_distances`` passes a
                third argument; unused here.

        Returns:
            ``(x, y)`` exactly as stored (no numeric conversion), or the
            sentinel ``(-3000, -3000)`` (flag true) / ``(-2000, -2000)``
            (flag false) when the point is missing or is not a pair.
        """
        try:
            x, y = content['POINT']
        except (KeyError, IndexError, TypeError, ValueError):
            if flag:
                return (-3000, -3000)
            return (-2000, -2000)
        return (x, y)

class Aguvis(Base_Model):
    """Point extractor reading the ``'POINT'`` entry of the prediction.

    The constructor is inherited from ``Base_Model``.
    """

    @staticmethod
    def _extract_point(content, flag, item=None):
        """Return the two values stored under ``content['POINT']``.

        Args:
            content: the ``predicted_action`` value; indexed with ``'POINT'``.
            flag: ``True`` for the result-file prediction, ``False`` for the
                raw-file prediction.
            item: accepted because ``Base_Model.compute_distances`` passes a
                third argument; unused here.

        Returns:
            ``(x, y)`` exactly as stored (no numeric conversion), or the
            sentinel ``(-3000, -3000)`` (flag true) / ``(2000, 2000)``
            (flag false) when the point is missing or is not a pair.
        """
        try:
            x, y = content['POINT']
        except (KeyError, IndexError, TypeError, ValueError):
            print("extract coordinates failure")
            if flag:
                return -3000, -3000
            # NOTE(review): most sibling extractors in this file use
            # (-2000, -2000) here; the positive value is kept unchanged.
            # Confirm which is intended.
            return 2000, 2000
        return (x, y)
    
class GUI_R1(Base_Model):
    """Point extractor for predictions stored as a Python-literal string.

    The constructor is inherited from ``Base_Model``.
    """

    @staticmethod
    def _extract_point(content, flag, item=None):
        """Return ``point[0], point[1]`` of the first parsed element.

        ``content`` is parsed with ``ast.literal_eval``; the point is read
        from ``parsed[0]['point']``.

        Args:
            content: the ``predicted_action`` value; a literal in text form.
            flag: ``True`` for the result-file prediction, ``False`` for the
                raw-file prediction.
            item: accepted because ``Base_Model.compute_distances`` passes a
                third argument; unused here.

        Returns:
            ``(x, y)`` as stored in the literal, or the sentinel
            ``(-3000, -3000)`` (flag true) / ``(2000, 2000)`` (flag false)
            on any parsing or lookup failure. The failure and the offending
            content are printed.
        """
        try:
            data = ast.literal_eval(content)
            point = data[0]['point']
            return point[0], point[1]
        except Exception as e:
            print(f"解析失败: {e}")
            print(content)
            if flag:
                return -3000, -3000
            # NOTE(review): most sibling extractors in this file use
            # (-2000, -2000) here; the positive value is kept unchanged.
            # Confirm which is intended.
            return 2000, 2000

class Agent_CPM(Base_Model):
    """Point extractor for predictions stored as a Python-literal string.

    The constructor is inherited from ``Base_Model``.
    """

    @staticmethod
    def _extract_point(content, flag, item=None):
        """Return the two values under ``'POINT'`` of the parsed literal.

        ``content`` is parsed with ``ast.literal_eval`` and then indexed
        with ``'POINT'``.

        Args:
            content: the ``predicted_action`` value; a literal in text form.
            flag: ``True`` for the result-file prediction, ``False`` for the
                raw-file prediction.
            item: accepted because ``Base_Model.compute_distances`` passes a
                third argument; unused here.

        Returns:
            ``(x, y)`` as stored in the literal, or the sentinel
            ``(-3000, -3000)`` (flag true) / ``(-2000, -2000)`` (flag false)
            on any parsing or lookup failure.
        """
        try:
            x, y = ast.literal_eval(content)['POINT']
        except Exception:
            # literal_eval can raise several unrelated error types; keep the
            # best-effort fallback without also trapping KeyboardInterrupt.
            print("extract coordinates failure")
            if flag:
                return (-3000, -3000)
            return (-2000, -2000)
        return (x, y)
    
class GUI_Owl(Base_Model):
    """Point extractor reading the ``'coordinate'`` entry of the prediction."""

    def __init__(self, res_path: str, raw_path: str):
        super().__init__(res_path, raw_path)

    @staticmethod
    def _extract_point(content, flag, item):
        """Return the pair stored under ``content['coordinate']``.

        When ``flag`` is true the pair is returned as stored. Otherwise each
        value is rescaled as ``value / 1000 * item['image_size'][axis]``.
        Any exception yields ``(None, None)``.
        """
        try:
            px, py = content['coordinate']
            if flag:
                return px, py
            size = item['image_size']
            scaled_x = px / 1000 * size[0]
            scaled_y = py / 1000 * size[1]
            return scaled_x, scaled_y
        except Exception:
            return None, None
  
if __name__ == "__main__":
    # NOTE(review): the extractor class is GUI_Owl while both paths name
    # UI-TARS-72B-DPO outputs; confirm the class matches the files.
    calc = GUI_Owl(
        "/Agent_ScanKit/results/visual_edit/low/UI-TARS-72B-DPO.json",
        "/Agent_ScanKit/datasets/json/visual_mask/low/UI-TARS-72B-DPO_raw.json"
    )

    # percentage_below_threshold() runs compute_distances() itself, so no
    # separate pass is made here. Other thresholds (e.g. 0 or 10) can be
    # reported the same way.
    print("vmc_50:", calc.percentage_below_threshold(50), "%")

    print("RS:", calc._get_ris()*100, "%")

    
